from __future__ import print_function
import torch
import argparse

from utils import load_data
from utils import landmark_o_crop

from network_and_loss import CoordNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float


def _str2bool(value):
    """Convert a command-line flag value to a real ``bool``.

    Without a ``type``, argparse stores the raw string, so ``--random_flip False``
    would yield the non-empty (and therefore truthy) string ``"False"``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got %r" % (value,))


parser = argparse.ArgumentParser(description='Prior modeling: CoordNet')

# Hyperparameters
parser.add_argument('--task', type=str, default="COFW", help='Task dataset')
# Boolean augmentation switches; _str2bool makes "False"/"0" actually disable them.
parser.add_argument('--random_scale', type=_str2bool, default=False,
                    help='Whether to apply random scaling')
parser.add_argument('--random_flip', type=_str2bool, default=False,
                    help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, default=False,
                    help='Whether to apply random rotation')

parser.add_argument('--batch_size', type=int, default=64, help='Batch size')

# Parsed at import time; main() reads this module-level `args`.
args = parser.parse_args()

def main():
    """Restore the trained CoordNet and compute its landmark prior on the training set.

    Reads the module-level ``args`` (task, batch size, augmentation flags), builds
    the training loader via ``load_data``, loads the weights stored under the
    ``'coordnet'`` key of ``../../Checkpoint/CoordNet_2000epoch.pth`` and hands
    both to ``prior_modeling``.
    """
    torch.cuda.empty_cache()
    train_loader, _ = load_data(args.task, args.batch_size, args.random_scale,
                                args.random_flip, args.random_rotation)

    coordnet = CoordNet().to(device)

    # map_location remaps tensors saved on a GPU onto `device`, so the checkpoint
    # also loads on a CPU-only machine (where `device` falls back to 'cpu').
    checkpoint = torch.load("../../Checkpoint/CoordNet_2000epoch.pth", map_location=device)
    coordnet.load_state_dict(checkpoint['coordnet'])

    prior_modeling(coordnet, train_loader)
    
    
    
def prior_modeling(coordnet, train_loader, num_landmarks=29,
                   save_path="../../Checkpoint/CoordNet_prior.pth"):
    """Average CoordNet latents per landmark over a data loader and save the result.

    Args:
        coordnet: trained CoordNet. It is put in eval mode here, and the first
            element of its output tuple is used as the latent code.
        train_loader: iterable of ``(images, landmark_coords)`` batches.
        num_landmarks: landmarks per sample; ``landmark_coords`` is viewed as
            ``(-1, num_landmarks, 2)``.
        save_path: file that receives ``{"landmark_z_mean": tensor}`` via ``torch.save``.

    Returns:
        The ``(num_landmarks, latent_dim)`` tensor of per-landmark mean latents
        (the same tensor that is saved).

    Raises:
        ValueError: if ``train_loader`` yields no samples, since the mean is undefined.
    """
    coordnet.eval()

    # Keep a running sum and sample count instead of torch.cat-ing every latent.
    # Concatenating re-copied the whole growing stack on every batch just to
    # take a mean at the end.
    z_sum = None
    num_samples = 0

    with torch.no_grad():
        for images, landmark_coords in train_loader:
            images, landmark_coords = images.to(device), landmark_coords.to(device)
            landmark_coords = landmark_coords.view(-1, num_landmarks, 2)
            landmark_o = landmark_o_crop(images, landmark_coords)
            # The network is fed 2-channel 27x27 patches, as in the original code.
            landmark_z, _ = coordnet(landmark_o.view(-1, 2, 27, 27))

            # NOTE(review): this view treats the flattened latents as landmark-major
            # (all samples of landmark 0 first). If landmark_o_crop returns
            # (batch, num_landmarks, ...), i.e. batch-major, the rows get mixed across
            # landmarks. In that case use view(-1, num_landmarks, D).transpose(0, 1).
            # Confirm against landmark_o_crop.
            per_landmark = landmark_z.view(num_landmarks, -1, landmark_z.shape[-1])

            batch_sum = per_landmark.sum(dim=1)
            z_sum = batch_sum if z_sum is None else z_sum + batch_sum
            num_samples += per_landmark.shape[1]

    if num_samples == 0:
        raise ValueError("train_loader yielded no samples; cannot compute the landmark prior")

    landmark_z_mean = z_sum / num_samples

    state = {
        "landmark_z_mean": landmark_z_mean
    }
    torch.save(state, save_path)
    return landmark_z_mean
    
    

if __name__ == "__main__":
    # Entry point: run only when executed as a script, never on import.
    main()
